################################## Results analysis - Real estate market #########################################
#
# Authors:
#
# Sandro Heiniger
# Swiss Institute for Empirical Economic Research
# University St. Gallen
# sandro.heiniger@unisg.ch
#
#### 



# Settings ----------------------------------------------------------------

# Make the project root the working directory. Relative paths used further
# down (e.g. "./Estimation_plots/...") are resolved against it.
# NOTE(review): `set_path` is not defined in the visible part of this file;
# presumably it has to be defined (or replaced by a literal path) before the
# script is run - confirm.
setwd(set_path)

# Attach the packages this script relies on through the install.load helper.
# NOTE(review): install_load() presumably also installs packages that are
# missing before attaching them - confirm against the install.load docs.
library(install.load)
install_load("tidyverse","ggplot2","tools","patchwork","viridis")


# one folder for all the specifications ---------------------------

# Split an ordered sequence of population counts into three contiguous groups
# (1 = low end, 2 = middle, 3 = high end) whose population totals are as equal
# as a greedy search can make them.
#
# The first element always starts in group 1 and the last in group 3; every
# other element starts in group 2. Each iteration considers two single moves:
# pushing the lowest middle element into group 1, or pushing the highest
# middle element into group 3. Whichever move gives the smaller variance of
# the three group totals is kept (a tie between the two favours the low move).
# The search stops as soon as neither move lowers the variance, or when only
# one element is left in group 2, so group 2 is never emptied.
#
# Args:
#   sequence: numeric vector of populations, already sorted in the order that
#             defines "low" to "high". Missing values are not handled.
#
# Returns:
#   Numeric vector of group labels (1, 2 or 3) with the same length as
#   `sequence`. An empty input gives an empty vector. A single element is
#   labelled 3, because the "last element" rule is applied after the
#   "first element" rule.
create_groups <- function(sequence) {
  pop <- sequence
  n <- length(pop)

  # Nothing to label: return an empty vector instead of inventing a row.
  if (n == 0L) {
    return(numeric(0))
  }

  groups <- rep(2, n)
  groups[1] <- 1
  groups[n] <- 3

  # Variance of the per-group population totals for a candidate labelling.
  # Base R only, so this helper does not depend on dplyr being attached.
  group_variance <- function(labels) {
    var(as.vector(tapply(pop, labels, sum)))
  }

  while (sum(groups == 2) > 1) {
    middle <- which(groups == 2)

    # Candidate 1: hand the lowest middle element to the low group.
    low_groups <- groups
    low_groups[min(middle)] <- 1

    # Candidate 2: hand the highest middle element to the high group.
    high_groups <- groups
    high_groups[max(middle)] <- 3

    best_variance <- group_variance(groups)
    low_variance <- group_variance(low_groups)
    high_variance <- group_variance(high_groups)

    improvement <- FALSE
    if (low_variance < best_variance) {
      groups <- low_groups
      best_variance <- low_variance
      improvement <- TRUE
    }
    # The high candidate is compared against the best labelling so far, so it
    # replaces the low candidate only when it is strictly better.
    if (high_variance < best_variance) {
      groups <- high_groups
      improvement <- TRUE
    }

    if (!improvement) {
      break
    }
  }

  groups
}

# Read the two input tables from CSV. Both calls use the same literal path,
# which is presumably a placeholder to be replaced with the real file
# locations - confirm.
estimation_data <- read.csv(file = "path_to_file")
RE21_price_data_estimation <- read.csv(file = "path_to_file")

# Run configuration. The names mirror the arguments of generate_gate_plots().

# Root directory of the estimation output and the sub-folder to read from;
# both are pasted together when building the input file paths.
path_to_results <- "set_path"
folder <- "set_folder"

# Base name of the ATE summary text file (".txt" is appended when reading).
load_summary_file <- "summary"

# Dimensions that are combined into the outcome names to load.
load_treatments <- c("Covid", "KA_di")
load_types <- c("rent", "sale")
load_channels <- c("officelogp", "residentiallogp", "retaillogp")
load_models <- c("fat", "lean")
# "" is labelled "unweighted" in the plots, any other value "weighted".
load_weighting <- c("", "WEIGHT")
# Inserted into the GATE file name; alternative value: "MATE".
load_mate <- ""

# Switches for the two plotting sections and for the GATE file-name prefix.
plot_ate <- TRUE
plot_gates <- TRUE
constant_names_for_gates <- FALSE

# Continuous and discrete covariates for which GATE results are read.
load_gate_cont <- c(
  "apt_share_owned_private_invest",
  "avg_apt_per_res_fac",
  "avg_apt_size",
  "ft_empl_wp_share_sector_tertiary_I_R",
  "ft_empl_wp_share_sector_tertiary_K",
  "HO_occ_index",
  "log_income_pp",
  "priv_alle_tech_200",
  "Q4_2019_rent_officelogp",
  "Q4_2019_rent_residentiallogp",
  "Q4_2019_rent_retaillogp"
)
load_gate_disc <- "urbanization"
# Quarter and comparison tags that appear in the GATE csv file names.
load_gate_quarter <- "Q1_2021"
load_gate_compare <- "3vs1"

# NOTE(review): presumably the covariates that get violin plots - its use is
# not visible in this part of the file, confirm.
gate_violin <- c(
  "priv_alle_tech_200",
  "log_income_pp",
  "urbanization",
  "HO_occ_index",
  "avg_apt_per_res_fac"
)

# Treatment-level comparisons kept in the ATE plots ("2_1", "3_2", "3_1").
show_categories <- "3_1"
#show_categories=c("3_2","2_1")

generate_gate_plots=function(estimation_data, path_to_results, folder, load_summary_file, load_treatments,load_types, 
                             load_channels, load_models, load_gate_cont,load_gate_disc,
                             load_gate_quarter,load_gate_compare, load_weighting, show_categories,
                             plot_ate, plot_gates, constant_names_for_gates){
  dir.create(paste0("./Estimation_plots/",folder), showWarnings = FALSE)
  categories_addon=paste0("_",paste(show_categories,collapse = "_and_"))
  
  # generate all combinations
  outcome_names=apply(expand.grid(load_treatments, load_types, load_channels, load_models, load_weighting), 1, paste, collapse="_")
  outcome_names=gsub('_$', '', outcome_names)
  
  if(plot_ate){
    ate_long=data.frame()
    for(current_weight in load_weighting){
      
      if(current_weight!=""){current_weight=paste0("_",current_weight)}
      
      # read the summary of the ATEs
      ate_summary=read.table(file=paste0(path_to_results,"/",folder,"/",load_summary_file,current_weight,".txt"),
                             stringsAsFactors = F,
                             header=F,
                             sep="",
                             fill=NA,
                             strip.white = T,
                             blank.lines.skip = T,
                             flush=T,
                             col.names = c("outcome","effect_2_1","se_2_1","effect_3_1","se_3_1","effect_3_2","se_3_2"))
      
      ate_summary=suppressWarnings(ate_summary %>% 
                                     mutate(across(2:ncol(ate_summary),as.numeric)) %>% 
                                     mutate(outcome=str_replace(outcome,":","")) %>% 
                                     filter(outcome %in% outcome_names) %>% 
                                     na.omit())
      
      ate_long=bind_rows(ate_long,
                         bind_rows(ate_summary %>% 
                                     select(outcome, effect_2_1, se_2_1) %>%
                                     `colnames<-`(c("outcome", "effect", "se")) %>% 
                                     mutate(group="2_1"),
                                   ate_summary %>% 
                                     select(outcome, effect_3_2, se_3_2) %>%
                                     `colnames<-`(c("outcome", "effect", "se")) %>% 
                                     mutate(group="3_2"),
                                   ate_summary %>% 
                                     select(outcome, effect_3_1, se_3_1) %>%
                                     `colnames<-`(c("outcome", "effect", "se")) %>% 
                                     mutate(group="3_1")) %>%
                           mutate(outcome=str_replace_all(str_replace_all(outcome,"di_",""),"logp","")) %>%
                           mutate(Treatment=str_replace(str_split(outcome,"_",simplify = T)[,1],pattern = "KA","STW"),
                                  Channel=toTitleCase(str_split(outcome,"_",simplify = T)[,2]),
                                  Type=toTitleCase(str_split(outcome,"_",simplify = T)[,3]),
                                  Model=str_replace(str_replace(str_split(outcome,"_",simplify = T)[,4],"fat","FULL"),"lean","EMH"),
                                  Weighting=ifelse(current_weight=="","unweighted","weighted")))
    }
    
    ate_long=ate_long %>% group_by(group, Treatment, Channel, Type, Model, Weighting) %>% 
      summarise(effect=mean(effect),
                se=mean(se),
                .groups = "drop")
    
    ate_long$group=factor(ate_long$group, levels = c("2_1","3_2","3_1"), ordered = T)
    ate_long$Treatment=factor(ate_long$Treatment, levels = c("Covid","STW"), ordered = T)
    ate_long$Channel=factor(ate_long$Channel, levels = c("Rent","Sale"), ordered = T)
    ate_long$Type=factor(ate_long$Type, levels = c("Residential","Retail","Office"), ordered = T)
    ate_long$Model=factor(ate_long$Model, levels = c("EMH","FULL"), ordered = T)
    ate_long$pvalue=round(pnorm(-abs(ate_long$effect), sd=ate_long$se)*2,3)
    
    # Overview the preferred specification
    plot_ate_only_EMH_weight=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="EMH", group %in% show_categories))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se), size=1, color="deepskyblue4", width=0.5) + 
      labs(y="Average treatment effect (ATE)",
           x="Between treatment dimension")+
      geom_point(aes(x=group, y=effect), size=4, shape=21, fill="white") +
      facet_grid(~Type+Treatment+Channel,scales='free') +
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme_light(base_size = 14)+
      theme(strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    ggsave(paste0("plot_ate_only_EMH_weight",categories_addon,".pdf"),
           plot_ate_only_EMH_weight,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    # Compare Channels
    
    plot_ate_by_channel_only_EMH_weight=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="EMH", group %in% show_categories))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_EMH_weight=plot_ate_by_channel_only_EMH_weight+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_EMH_weight=plot_ate_by_channel_only_EMH_weight+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_EMH_weight",categories_addon,".pdf"),
           plot_ate_by_channel_only_EMH_weight,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_channel_only_EMH_unweight=ggplot(
      ate_long %>% filter(Weighting=="unweighted", Model=="EMH", group %in% show_categories))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_EMH_unweight=plot_ate_by_channel_only_EMH_unweight+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_EMH_unweight=plot_ate_by_channel_only_EMH_unweight+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_EMH_unweight",categories_addon,".pdf"),
           plot_ate_by_channel_only_EMH_unweight,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_channel_only_FULL_weight=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="FULL", group %in% show_categories))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_FULL_weight=plot_ate_by_channel_only_FULL_weight+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_FULL_weight=plot_ate_by_channel_only_FULL_weight+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_FULL_weight",categories_addon,".pdf"),
           plot_ate_by_channel_only_FULL_weight,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    # Only one type
    plot_ate_by_channel_only_EMH_weight_residential=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="EMH", group %in% show_categories, Type=="Residential"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.3, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_EMH_weight_residential=plot_ate_by_channel_only_EMH_weight_residential+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_EMH_weight_residential=plot_ate_by_channel_only_EMH_weight_residential+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_EMH_weight_residential",categories_addon,".pdf"),
           plot_ate_by_channel_only_EMH_weight_residential,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_channel_only_EMH_weight_retail=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="EMH", group %in% show_categories, Type=="Retail"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.3, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_EMH_weight_retail=plot_ate_by_channel_only_EMH_weight_retail+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_EMH_weight_retail=plot_ate_by_channel_only_EMH_weight_retail+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_EMH_weight_retail",categories_addon,".pdf"),
           plot_ate_by_channel_only_EMH_weight_retail,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_channel_only_EMH_weight_office=ggplot(
      ate_long %>% filter(Weighting=="weighted", Model=="EMH", group %in% show_categories, Type=="Office"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Channel), size=1, width=0.3, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Channel), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_channel_only_EMH_weight_office=plot_ate_by_channel_only_EMH_weight_office+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_channel_only_EMH_weight_office=plot_ate_by_channel_only_EMH_weight_office+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_channel_only_EMH_weight_office",categories_addon,".pdf"),
           plot_ate_by_channel_only_EMH_weight_office,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    # By covariate set
    
    plot_ate_by_model_only_weight_rent=ggplot(
      ate_long %>% filter(Weighting=="weighted", group %in% show_categories,Channel=="Rent"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Model), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Model), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_model_only_weight_rent=plot_ate_by_model_only_weight_rent+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_model_only_weight_rent=plot_ate_by_model_only_weight_rent+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_model_only_weight_rent",categories_addon,".pdf"),
           plot_ate_by_model_only_weight_rent,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_model_only_weight_sale=ggplot(
      ate_long %>% filter(Weighting=="weighted", group %in% show_categories,Channel=="Sale"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Model), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Model), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_model_only_weight_sale=plot_ate_by_model_only_weight_sale+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_model_only_weight_sale=plot_ate_by_model_only_weight_sale+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_model_only_weight_sale",categories_addon,".pdf"),
           plot_ate_by_model_only_weight_sale,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_model_only_unweight_rent=ggplot(
      ate_long %>% filter(Weighting=="unweighted", group %in% show_categories,Channel=="Rent"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Model), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Model), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_model_only_unweight_rent=plot_ate_by_model_only_unweight_rent+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_model_only_unweight_rent=plot_ate_by_model_only_unweight_rent+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_model_only_unweight_rent",categories_addon,".pdf"),
           plot_ate_by_model_only_unweight_rent,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_model_only_unweight_sale=ggplot(
      ate_long %>% filter(Weighting=="unweighted", group %in% show_categories,Channel=="Sale"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Model), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Model), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_model_only_unweight_sale=plot_ate_by_model_only_unweight_sale+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_model_only_unweight_sale=plot_ate_by_model_only_unweight_sale+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_model_only_unweight_sale",categories_addon,".pdf"),
           plot_ate_by_model_only_unweight_sale,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    # By weighting
    
    plot_ate_by_weight_only_EMH_rent=ggplot(
      ate_long %>% filter(Model=="EMH", group %in% show_categories,Channel=="Rent"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Weighting), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Weighting), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_weight_only_EMH_rent=plot_ate_by_weight_only_EMH_rent+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_weight_only_EMH_rent=plot_ate_by_weight_only_EMH_rent+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_weight_only_EMH_rent",categories_addon,".pdf"),
           plot_ate_by_weight_only_EMH_rent,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    plot_ate_by_weight_only_EMH_sale=ggplot(
      ate_long %>% filter(Model=="EMH", group %in% show_categories,Channel=="Sale"))+
      geom_hline(yintercept = 0, size=0.9, color="black") + 
      labs(y="Average treatment effect (ATE)",
           color="")+
      geom_errorbar(aes(x=group, ymin=effect-1.65*se, ymax=effect+1.65*se, color=Weighting), size=1, width=0.5, position=position_dodge(width=0.5)) + 
      geom_point(aes(x=group, y=effect, color=Weighting), fill="White", size=4, stroke=1.5, shape=21,position=position_dodge(width=0.5)) +
      facet_grid(~Type+Treatment,scales='free') +
      theme_light(base_size=14)+
      scale_color_viridis_d(option="viridis",direction=1,begin = 0.2, end=0.75)+
      theme(plot.title = element_text(hjust = 0.5),
            legend.position="bottom",
            legend.margin=margin(0,0,0,0),
            legend.box.margin=margin(-10,0,0,0),
            strip.text = element_text(size = 10, color = "black", lineheight=0))
    
    if(show_categories==c("3_1")){
      plot_ate_by_weight_only_EMH_sale=plot_ate_by_weight_only_EMH_sale+
        labs(x="Change from low to high incidence")+
        theme(axis.ticks.x = element_blank(),
              axis.text.x = element_blank())
    } else {
      plot_ate_by_weight_only_EMH_sale=plot_ate_by_weight_only_EMH_sale+
        labs(x="Between treatment dimension")
    }
    
    ggsave(paste0("plot_ate_by_weight_only_EMH_sale",categories_addon,".pdf"),
           plot_ate_by_weight_only_EMH_sale,
           device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
           path = paste0("./Estimation_plots/",folder,"/"))
    
    
    # Output for the latex table
    
    ate_summary_latex=bind_rows(data.frame("latex_out"="Treat. & Outc. & Spec. & 2-1 Eff. & 2-1 CI & 3-2 Eff. & 3-2 CI & 3-1 Eff. & 3-1 CI\\\\"),
                                ate_summary %>% 
                                  mutate(sign_2_1=(sign(effect_2_1-1.65*se_2_1)==sign(effect_2_1+1.65*se_2_1)),
                                         sign_3_1=(sign(effect_3_1-1.65*se_3_1)==sign(effect_3_1+1.65*se_3_1)),
                                         sign_3_2=(sign(effect_3_2-1.65*se_3_2)==sign(effect_3_2+1.65*se_3_2))) %>%
                                  mutate(outcome_latex=str_replace_all(str_replace_all(str_replace_all(outcome,"di_",""),"_"," & "),"logp",""),
                                         CI_2_1=paste0("[",format(round(effect_2_1-1.65*se_2_1,3),digits=3),",",format(round(effect_2_1+1.65*se_2_1,3),digits=3),"]"),
                                         CI_3_2=paste0("[",format(round(effect_3_2-1.65*se_3_2,3),digits=3),",",format(round(effect_3_2+1.65*se_3_2,3),digits=3),"]"),
                                         CI_3_1=paste0("[",format(round(effect_3_1-1.65*se_3_1,3),digits=3),",",format(round(effect_3_1+1.65*se_3_1,3),digits=3),"]")) %>% 
                                  mutate(effect_2_1=ifelse(sign_2_1,
                                                           paste0("\\textbf{",format(round(effect_2_1,3),digits=3),"}"),
                                                           format(round(effect_2_1,3),digits=3)),
                                         effect_3_1=ifelse(sign_3_1,
                                                           paste0("\\textbf{",format(round(effect_3_1,3),digits=3),"}"),
                                                           format(round(effect_3_1,3),digits=3)),
                                         effect_3_2=ifelse(sign_3_2,
                                                           paste0("\\textbf{",format(round(effect_3_2,3),digits=3),"}"),
                                                           format(round(effect_3_2,3),digits=3)),
                                         CI_2_1=ifelse(sign_2_1,paste0("\\textbf{",CI_2_1,"}"),CI_2_1),
                                         CI_3_1=ifelse(sign_3_1,paste0("\\textbf{",CI_3_1,"}"),CI_3_1),
                                         CI_3_2=ifelse(sign_3_2,paste0("\\textbf{",CI_3_2,"}"),CI_3_2)
                                  ) %>%
                                  mutate(latex_out=paste(paste(outcome_latex, 
                                                               str_remove_all(effect_2_1," "), 
                                                               str_remove_all(CI_2_1," "),  
                                                               str_remove_all(effect_3_2," "),  
                                                               str_remove_all(CI_3_2," "),  
                                                               str_remove_all(effect_3_1," "), 
                                                               str_remove_all(CI_3_1," "), 
                                                               sep=" & "),
                                                         " \\\\",sep="")) %>%
                                  select(latex_out))
    write.table(x=ate_summary_latex, 
                file=paste0("./Estimation_plots/",folder,"/ATE_summary.txt"),
                row.names = F, col.names = F, quote = F)
  }
  
  if(plot_gates){
    for(current_outcome in outcome_names){
      if(constant_names_for_gates){
        gates_prename=""
      } else{
        gates_prename=paste0(current_outcome,"_")
      }
      dir.create(paste0("./Estimation_plots/",folder, "/",str_replace_all(current_outcome,"_","-")), showWarnings = FALSE)
      for(current_gate in c(load_gate_cont,load_gate_disc)){
        # read the data for this gate
        if(str_split(current_outcome,"_")[[1]][1]=="Covid"){
          short_outcome=paste(str_split(current_outcome,"_")[[1]][2:3],collapse="_")
        } else {
          short_outcome=paste(str_split(current_outcome,"_")[[1]][3:4],collapse="_")
        }
        
        current_gate_data=read.csv(file=paste0(path_to_results,
                                               "/",
                                               folder,
                                               "/",
                                               "out",
                                               "_",
                                               current_outcome,
                                               "/",
                                               "fig_csv",
                                               "/",
                                               "GATE",
                                               load_mate,
                                               toupper(current_gate),
                                               "All",
                                               load_gate_quarter,
                                               "_",
                                               toupper(short_outcome),
                                               "LC",
                                               load_gate_compare,
                                               ".csv"))
        
        # transform to long format
        current_gate_data=bind_rows(current_gate_data %>% select(upper, effects, lower, z_values) %>% mutate(type="gate"),
                                    current_gate_data %>% select(ate, ate_upper, ate_lower, z_values)  %>% 
                                      rename(c("upper"="ate_upper", "effects"="ate", "lower"="ate_lower")) %>% mutate(type="ate"))
        
        # create plot
        ggplot_labels <- c(ate = "ATE & 90% CI", gate = "GATE & 90% CI")
        ggplot_breaks <- c("ate", "gate")
        gate_plot=ggplot(data=current_gate_data, aes(x=z_values)) + 
          geom_hline(yintercept = 0, size=0.9, color="black") +
          geom_line(aes(y=effects, color=type), size=1.1) + 
          geom_line(aes(y=upper, color=type), size=0.9) +
          geom_line(aes(y=lower, color=type), size=0.9) +
          geom_ribbon(aes(ymin = lower, ymax = upper, fill = type), alpha = 0.3)+
          scale_color_viridis_d(breaks = ggplot_breaks, labels = ggplot_labels,
                                option="viridis",direction=1,begin = 0.2, end=0.75)+
          scale_fill_viridis_d(breaks = ggplot_breaks, labels = ggplot_labels,
                               option="viridis",direction=1,begin = 0.2, end=0.75)+
          labs(#title = paste0("GATE by ",current_gate),
            x=current_gate, y="GATE", 
            color = NULL, 
            fill = NULL)+
          theme_light(base_size = 16)+
          theme(legend.position = c(0.15, 0.85),
                legend.background = element_rect(fill="grey95",
                                                 linewidth = 0.6, linetype="solid", 
                                                 colour ="grey40"),
                plot.title = element_text(hjust = 0.5))
        # if discrete, reduce the tick marks
        if(current_gate %in% load_gate_disc){
          gate_plot=gate_plot+scale_x_continuous(breaks=sort(unique(current_gate_data$z_values)))
        }
        
        # Save the plot
        ggsave(str_replace_all(paste0(gates_prename,current_gate,".pdf"),"_","-"),
               gate_plot,
               device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
               path = paste0("./Estimation_plots/",folder,"/",str_replace_all(current_outcome,"_","-"),"/"))
      }
      
      # read the sorted effects data
      current_iate_data=read.csv(file=paste0(path_to_results,"/",folder,
                                             "/out_",
                                             current_outcome,
                                             "/fig_csv/Sorted",
                                             load_gate_quarter,
                                             "_",
                                             toupper(short_outcome),
                                             "LC",
                                             load_gate_compare,
                                             "_iate.csv"))
      
      current_iate_data$z_values=1:nrow(current_iate_data)
      
      # transform to long format
      current_iate_data=bind_rows(current_iate_data %>% select(upper, effects, lower, z_values) %>% mutate(type="iate"),
                                  current_iate_data %>% select(ate, ate_u, ate_l, z_values)  %>% 
                                    rename(c("upper"="ate_u", "effects"="ate", "lower"="ate_l")) %>% mutate(type="ate"))
      
      # create plot
      ggplot_labels <- c(ate = "ATE & 90% CI", iate = "IATE & 90% CI")
      ggplot_breaks <- c("ate", "iate")
      gate_plot=ggplot(data=current_iate_data, aes(x=z_values/max(z_values))) + 
        geom_hline(yintercept = 0, size=0.9, color="black") +
        geom_line(aes(y=effects, color=type), size=1.1) + 
        geom_line(aes(y=upper, color=type), size=0.9) +
        geom_line(aes(y=lower, color=type), size=0.9) +
        geom_ribbon(aes(ymin = lower, ymax = upper, fill = type), alpha = 0.3)+
        scale_color_viridis_d(breaks = ggplot_breaks, labels = ggplot_labels,
                              option="viridis",direction=1,begin = 0.2, end=0.75)+
        scale_fill_viridis_d(breaks = ggplot_breaks, labels = ggplot_labels,
                             option="viridis",direction=1,begin = 0.2, end=0.75)+
        labs(x="Sorted observation by IATE", y="IATE", color = NULL, fill = NULL)+
      	scale_x_continuous(labels = scales::percent)+
        theme_light(base_size = 16)+
        theme(legend.position = c(0.15, 0.85),
              legend.background = element_rect(fill="grey95",
                                               size=0.6, linetype="solid", 
                                               colour ="grey40"))
      
      # Save the plot
      ggsave(paste0(gates_prename,"Sorted-effects.pdf"),
             gate_plot,
             device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
             path = paste0("./Estimation_plots/",folder,"/",str_replace_all(current_outcome,"_","-"),"/"))
      
      # Start with the violin plot
      if(grepl("resid", current_outcome, fixed = TRUE)){
        current_price="Q4_2019_rent_residential"
      } else if(grepl("office", current_outcome, fixed = TRUE)){
        current_price="Q4_2019_rent_office"
      } else if(grepl("retail", current_outcome, fixed = TRUE)){
        current_price="Q4_2019_rent_retail"
      }
      gate_violin_price=c(paste0(current_price,"logp"),gate_violin)
      
      price_data=RE21_price_data_estimation %>% 
        filter(RE21_price_id==current_price) %>% 
        select(ags_9, price)
      price_data[,paste0(current_price,"logp")]=log(price_data$price)
      estimation_data_red=left_join(estimation_data,
                                    price_data %>% select(-price),
                                    by="ags_9")
      
      estimation_data_red=estimation_data_red[,tolower(colnames(estimation_data_red)) %in% tolower(gate_violin_price)]
      estimation_data_red=pivot_longer(estimation_data_red, 
                                       cols=everything(),
                                       names_to="Variable")
      
      # read the outfile
      py_out=read.table(file=paste0(path_to_results,"/",folder,
                                    "/out_",
                                    str_replace_all(current_outcome,"-","_"),
                                    "/est",
                                    str_replace_all(current_outcome,"-","_"),
                                    ".txt"),
                        fill=T,
                        blank.lines.skip=T,
                        col.names=paste0("Col",1:10))
      
      Py_out_reduced=py_out[(last(which(py_out[,1]=="Covariates"))+3):(last(which(py_out[,1]=="Features"))-2),]
      Py_out_reduced=Py_out_reduced[tolower(Py_out_reduced[,1]) %in% tolower(gate_violin_price),]
      Py_out_reduced_wide=Py_out_reduced %>% slice(1:length(gate_violin_price))
      Py_out_reduced_wide= Py_out_reduced_wide[,which(Py_out_reduced_wide[1,]!="")]
      Py_out_reduced=Py_out_reduced[-(1:length(gate_violin_price)),]
      while(nrow(Py_out_reduced)>=length(gate_violin_price)){
        Py_out_reduced_wide=left_join(Py_out_reduced_wide,
                                      Py_out_reduced %>% slice(1:length(gate_violin_price)),
                                      by="Col1")
        Py_out_reduced_wide= Py_out_reduced_wide[,which(Py_out_reduced_wide[1,]!="")]
        Py_out_reduced=Py_out_reduced[-(1:length(gate_violin_price)),]
      }
      
      colnames(Py_out_reduced_wide)=c("Variable",paste0("Col",1:(ncol(Py_out_reduced_wide)-1)))
      Py_out_reduced_wide=Py_out_reduced_wide %>% mutate(across(starts_with("Col"),as.numeric))
      
      Py_out_groups=py_out[py_out[,1]=="group",]
      Py_out_groups$Col3=as.numeric(Py_out_groups$Col3)
      Py_out_groups$group <- create_groups(Py_out_groups$Col3) 

      Py_out_reduced_wide[,c("Cl-1")]=as.matrix(Py_out_reduced_wide[,which(Py_out_groups$group==1)+1]) %*% 
        as.vector(Py_out_groups$Col3[which(Py_out_groups$group==1)]) /
        sum(Py_out_groups$Col3[which(Py_out_groups$group==1)])
      
      Py_out_reduced_wide[,c("Cl-2")]=as.matrix(Py_out_reduced_wide[,which(Py_out_groups$group==2)+1]) %*% 
        as.vector(Py_out_groups$Col3[which(Py_out_groups$group==2)]) /
        sum(Py_out_groups$Col3[which(Py_out_groups$group==2)])
      
      Py_out_reduced_wide[,c("Cl-3")]=as.matrix(Py_out_reduced_wide[,which(Py_out_groups$group==3)+1]) %*% 
        as.vector(Py_out_groups$Col3[which(Py_out_groups$group==3)]) /
        sum(Py_out_groups$Col3[which(Py_out_groups$group==3)])
      
      Py_out_reduced_wide=Py_out_reduced_wide %>% arrange(Variable)
      Py_out_reduced_wide$Variable=as.factor(Py_out_reduced_wide$Variable)
      
      colors <- c("low" = viridis(n=3, begin=0.1, end=0.9)[3], 
                  "mid" = viridis(n=3, begin=0.1, end=0.9)[2], 
                  "high" = viridis(n=3, begin=0.1, end=0.9)[1])
      for(gate_variable in gate_violin_price){
        
        assign(paste0("Cl1",gate_variable),Py_out_reduced_wide %>% 
                 filter(tolower(Variable)==tolower(gate_variable)) %>% pull("Cl-1"))
        assign(paste0("Cl2",gate_variable),Py_out_reduced_wide %>% 
                 filter(tolower(Variable)==tolower(gate_variable)) %>% pull("Cl-2"))
        assign(paste0("Cl3",gate_variable),Py_out_reduced_wide %>% 
                 filter(tolower(Variable)==tolower(gate_variable)) %>% pull("Cl-3"))
        
        current_plot=ggplot() + 
          geom_violin(data=estimation_data_red %>% filter(Variable==gate_variable), 
                      aes(x=Variable,y=value),trim = F) +
          labs(x="",y="", color="Cluster")+
          scale_x_discrete(labels = function(x) str_wrap(ifelse(x=="priv_alle_tech_200","broadband coverage",str_replace_all(x, "_" , " ")),
                                                         width = 11))+
          theme_light(base_size = 14)+
          geom_hline(aes(yintercept = !!sym(paste0("Cl1",gate_variable)),
                         color="low"),
                     size=1.5)+
          geom_hline(aes(yintercept = !!sym(paste0("Cl2",gate_variable)),
                         color="mid"),
                     size=1.5)+
          geom_hline(aes(yintercept = !!sym(paste0("Cl3",gate_variable)),
                         color="high"),
                     size=1.5) +
          scale_color_manual(values = colors)+ 
          theme(legend.position = "bottom",
                axis.text = element_text(size=10))
        if(gate_variable=="urbanization"){
          current_plot=current_plot+ scale_y_reverse(breaks=c(1,2,3), labels=c("high","medium","low"))
        }
        
        if(gate_variable == gate_violin_price[1]){
          gate_cluster_plot=current_plot
        } else {
          gate_cluster_plot=(gate_cluster_plot+current_plot)
        }
        
      }
      gate_cluster_plot=gate_cluster_plot+ 
        plot_layout(ncol = length(gate_violin_price)) + 
        plot_layout(guides = "collect") &
        theme(legend.position = "bottom",
              legend.margin=margin(0,0,0,0),
              legend.box.margin=margin(-20,0,-10,0))
      
      ggsave(paste0(gates_prename,"gate_cluster_plot.pdf"),
             gate_cluster_plot,
             device = "pdf", width = 20, height = 10, units = "cm",dpi = 300, 
             path = paste0("./Estimation_plots/",folder,"/",str_replace_all(current_outcome,"_","-"),"/"))
    }
  }
  
}


# Run the results analysis with the script-level settings.
# Every argument is passed positionally, so the order below has to match the
# parameter order of generate_gate_plots().
generate_gate_plots(
  estimation_data,
  path_to_results,
  folder,
  load_summary_file,
  load_treatments,
  load_types,
  load_channels,
  load_models,
  load_gate_cont,
  load_gate_disc,
  load_gate_quarter,
  load_gate_compare,
  load_weighting,
  show_categories,
  plot_ate,
  plot_gates,
  constant_names_for_gates
)



